# PyAPNs was developed by Simon Whitaker <simon@goosoftware.co.uk>
# Source available at https://github.com/simonwhitaker/PyAPNs
#
# PyAPNs is distributed under the terms of the MIT license.
#
# Copyright (c) 2011 Goo Software Ltd
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from binascii import a2b_hex, b2a_hex
from datetime import datetime
from socket import socket, timeout, AF_INET, SOCK_STREAM
from socket import error as socket_error
from struct import pack, unpack
import sys
import ssl
import select
import time
import collections, itertools
import logging
import threading
try:
    from ssl import wrap_socket, SSLError
except ImportError:
    from socket import ssl as wrap_socket, sslerror as SSLError

from _ssl import SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE

import json

_logger = logging.getLogger(__name__)

MAX_PAYLOAD_LENGTH = 2048

NOTIFICATION_COMMAND = 0
ENHANCED_NOTIFICATION_COMMAND = 1

NOTIFICATION_FORMAT = (
    '!'   # network big-endian
    'B'   # command
    'H'   # token length
    '32s' # token
    'H'   # payload length
    '%ds' # payload
)

ENHANCED_NOTIFICATION_FORMAT = (
    '!'   # network big-endian
    'B'   # command
    'I'   # identifier
    'I'   # expiry
    'H'   # token length
    '32s' # token
    'H'   # payload length
    '%ds' # payload
)

ERROR_RESPONSE_FORMAT = (
    '!'   # network big-endian
    'B'   # command
    'B'   # status
    'I'   # identifier
)

TOKEN_LENGTH = 32
ERROR_RESPONSE_LENGTH = 6
DELAY_RESEND_SEC = 0.0
SENT_BUFFER_QTY = 100000
WAIT_WRITE_TIMEOUT_SEC = 10
WAIT_READ_TIMEOUT_SEC = 10
WRITE_RETRY = 3

ER_STATUS = 'status'
ER_IDENTIFER = 'identifier'

class APNs(object):
    """A class representing an Apple Push Notification service connection"""

    def __init__(self, use_sandbox=False, cert_file=None, key_file=None, enhanced=False):
        """
        Set use_sandbox to True to use the sandbox (test) APNs servers.
        Default is False.
        """
        super(APNs, self).__init__()
        self.use_sandbox = use_sandbox
        self.cert_file = cert_file
        self.key_file = key_file
        self._feedback_connection = None
        self._gateway_connection = None
        self.enhanced = enhanced

    @staticmethod
    def packed_uchar(num):
        """
        Returns an unsigned char in packed form
        """
        return pack('>B', num)

    @staticmethod
    def packed_ushort_big_endian(num):
        """
        Returns an unsigned short in packed big-endian (network) form
        """
        return pack('>H', num)

    @staticmethod
    def unpacked_ushort_big_endian(bytes):
        """
        Returns an unsigned short from a packed big-endian (network) byte
        array
        """
        return unpack('>H', bytes)[0]

    @staticmethod
    def packed_uint_big_endian(num):
        """
        Returns an unsigned int in packed big-endian (network) form
        """
        return pack('>I', num)

    @staticmethod
    def unpacked_uint_big_endian(bytes):
        """
        Returns an unsigned int from a packed big-endian (network) byte array
        """
        return unpack('>I', bytes)[0]

    @staticmethod
    def unpacked_char_big_endian(bytes):
        """
        Returns an unsigned char from a packed big-endian (network) byte array
        """
        return unpack('c', bytes)[0]

    @property
    def feedback_server(self):
        if not self._feedback_connection:
            self._feedback_connection = FeedbackConnection(
                use_sandbox = self.use_sandbox,
                cert_file = self.cert_file,
                key_file = self.key_file
            )
        return self._feedback_connection

    @property
    def gateway_server(self):
        if not self._gateway_connection:
            self._gateway_connection = GatewayConnection(
                use_sandbox = self.use_sandbox,
                cert_file = self.cert_file,
                key_file = self.key_file,
                enhanced = self.enhanced
            )
        return self._gateway_connection


class APNsConnection(object):
    """
    A generic connection class for communicating with the APNs
    """
    def __init__(self, cert_file=None, key_file=None, timeout=None, enhanced=False):
        super(APNsConnection, self).__init__()
        self.cert_file = cert_file
        self.key_file = key_file
        self.timeout = timeout
        self._socket = None
        self._ssl = None
        self.enhanced = enhanced
        self.connection_alive = False

    def _connect(self):
        # Establish an SSL connection
        _logger.debug("%s APNS connection establishing..." % self.__class__.__name__)

        # Fallback for socket timeout.
        for i in range(3):
            try:
                self._socket = socket(AF_INET, SOCK_STREAM)
                self._socket.settimeout(self.timeout)
                self._socket.connect((self.server, self.port))
                break
            except timeout:
                pass
            except:
                raise

        if self.enhanced:
            self._last_activity_time = time.time()
            self._socket.setblocking(False)
            self._ssl = wrap_socket(self._socket, self.key_file, self.cert_file,
                                        do_handshake_on_connect=False)
            while True:
                try:
                    self._ssl.do_handshake()
                    break
                except ssl.SSLError as err:
                    if ssl.SSL_ERROR_WANT_READ == err.args[0]:
                        select.select([self._ssl], [], [])
                    elif ssl.SSL_ERROR_WANT_WRITE == err.args[0]:
                        select.select([], [self._ssl], [])
                    else:
                        raise

        else:
            # Fallback for 'SSLError: _ssl.c:489: The handshake operation timed out'
            for i in range(3):
                try:
                    self._ssl = wrap_socket(self._socket, self.key_file, self.cert_file)
                    break
                except SSLError as ex:
                    if ex.args[0] == SSL_ERROR_WANT_READ:
                        sys.exc_clear()
                    elif ex.args[0] == SSL_ERROR_WANT_WRITE:
                        sys.exc_clear()
                    else:
                       raise

        self.connection_alive = True
        _logger.debug("%s APNS connection established" % self.__class__.__name__)

    def _disconnect(self):
        if self.connection_alive:
            if self._socket:
                self._socket.close()
            if self._ssl:
                self._ssl.close()
            self.connection_alive = False
            _logger.info(" %s APNS connection closed" % self.__class__.__name__)

    def _connection(self):
        if not self._ssl or not self.connection_alive:
            self._connect()
        return self._ssl

    def read(self, n=None):
        return self._connection().read(n)

    def write(self, string):
        if self.enhanced: # nonblocking socket
            self._last_activity_time = time.time()
            _, wlist, _ = select.select([], [self._connection()], [], WAIT_WRITE_TIMEOUT_SEC)

            if len(wlist) > 0:
                length = self._connection().sendall(string)
                if length == 0:
                    _logger.debug("sent length: %d" % length) #DEBUG
            else:
                _logger.warning("write socket descriptor is not ready after " + str(WAIT_WRITE_TIMEOUT_SEC))

        else: # blocking socket
            return self._connection().write(string)


class PayloadAlert(object):
    def __init__(self, body=None, title = None, subtitle = None, action_loc_key=None, loc_key=None,
                 loc_args=None, launch_image=None):
        super(PayloadAlert, self).__init__()
        
        self.body = body
        self.title = title
        self.subtitle = subtitle
        self.action_loc_key = action_loc_key
        self.loc_key = loc_key
        self.loc_args = loc_args
        self.launch_image = launch_image

    def dict(self):
        d = {}
        
        if self.body:
            d['body'] = self.body
        if self.title:
            d['title'] = self.title
        if self.subtitle:
            d['subtitle'] = self.subtitle
        if self.action_loc_key:
            d['action-loc-key'] = self.action_loc_key
        if self.loc_key:
            d['loc-key'] = self.loc_key
        if self.loc_args:
            d['loc-args'] = self.loc_args
        if self.launch_image:
            d['launch-image'] = self.launch_image
        return d

class PayloadTooLargeError(Exception):
    def __init__(self, payload_size):
        super(PayloadTooLargeError, self).__init__()
        self.payload_size = payload_size

class Payload(object):
    """A class representing an APNs message payload"""
    def __init__(self, alert=None, badge=None, sound=None, category=None, custom=None, content_available=False,
                 mutable_content=False):
        super(Payload, self).__init__()
        self.alert = alert
        self.badge = badge
        self.sound = sound
        self.category = category
        self.custom = custom
        self.content_available = content_available
        self.mutable_content = mutable_content
        self._check_size()

    def dict(self):
        """Returns the payload as a regular Python dictionary"""
        d = {}
        if self.alert:
            # Alert can be either a string or a PayloadAlert
            # object
            if isinstance(self.alert, PayloadAlert):
                d['alert'] = self.alert.dict()
            else:
                d['alert'] = self.alert
        if self.sound:
            d['sound'] = self.sound
        if self.badge is not None:
            d['badge'] = int(self.badge)
        if self.category:
            d['category'] = self.category

        if self.content_available:
            d.update({'content-available': 1})

        if self.mutable_content:
            d.update({'mutable-content': 1})

        d = { 'aps': d }
        if self.custom:
            d.update(self.custom)
        return d

    def json(self):
        return json.dumps(self.dict(), separators=(',',':'), ensure_ascii=False).encode('utf-8')

    def _check_size(self):
        payload_length = len(self.json())
        if payload_length > MAX_PAYLOAD_LENGTH:
            raise PayloadTooLargeError(payload_length)

    def __repr__(self):
        attrs = ("alert", "badge", "sound", "category", "custom")
        args = ", ".join(["%s=%r" % (n, getattr(self, n)) for n in attrs])
        return "%s(%s)" % (self.__class__.__name__, args)

class Frame(object):
    """A class representing an APNs message frame for multiple sending"""
    def __init__(self):
        self.frame_data = bytearray()
        self.notification_data = list()

    def get_frame(self):
        return self.frame_data

    def add_item(self, token_hex, payload, identifier, expiry, priority):
        """Add a notification message to the frame"""
        item_len = 0
        self.frame_data.extend(b'\2' + APNs.packed_uint_big_endian(item_len))

        token_bin = a2b_hex(token_hex)
        token_length_bin = APNs.packed_ushort_big_endian(len(token_bin))
        token_item = b'\1' + token_length_bin + token_bin
        self.frame_data.extend(token_item)
        item_len += len(token_item)

        payload_json = payload.json()
        payload_length_bin = APNs.packed_ushort_big_endian(len(payload_json))
        payload_item = b'\2' + payload_length_bin + payload_json
        self.frame_data.extend(payload_item)
        item_len += len(payload_item)

        identifier_bin = APNs.packed_uint_big_endian(identifier)
        identifier_length_bin = \
                APNs.packed_ushort_big_endian(len(identifier_bin))
        identifier_item = b'\3' + identifier_length_bin + identifier_bin
        self.frame_data.extend(identifier_item)
        item_len += len(identifier_item)

        expiry_bin = APNs.packed_uint_big_endian(expiry)
        expiry_length_bin = APNs.packed_ushort_big_endian(len(expiry_bin))
        expiry_item = b'\4' + expiry_length_bin + expiry_bin
        self.frame_data.extend(expiry_item)
        item_len += len(expiry_item)

        priority_bin = APNs.packed_uchar(priority)
        priority_length_bin = APNs.packed_ushort_big_endian(len(priority_bin))
        priority_item = b'\5' + priority_length_bin + priority_bin
        self.frame_data.extend(priority_item)
        item_len += len(priority_item)

        self.frame_data[-item_len-4:-item_len] = APNs.packed_uint_big_endian(item_len)

        self.notification_data.append({'token':token_hex, 'payload':payload, 'identifier':identifier, 'expiry':expiry, "priority":priority})

    def get_notifications(self, gateway_connection):
        notifications = list({'id': x['identifier'], 'message':gateway_connection._get_enhanced_notification(x['token'], x['payload'],x['identifier'], x['expiry'])} for x in self.notification_data)
        return notifications

    def __str__(self):
        """Get the frame buffer"""
        return str(self.frame_data)

class FeedbackConnection(APNsConnection):
    """
    A class representing a connection to the APNs Feedback server
    """
    def __init__(self, use_sandbox=False, **kwargs):
        super(FeedbackConnection, self).__init__(**kwargs)
        self.server = (
            'feedback.push.apple.com',
            'feedback.sandbox.push.apple.com')[use_sandbox]
        self.port = 2196

    def _chunks(self):
        BUF_SIZE = 4096
        while 1:
            data = self.read(BUF_SIZE)
            yield data
            if not data:
                break

    def items(self):
        """
        A generator that yields (token_hex, fail_time) pairs retrieved from
        the APNs feedback server
        """
        buff = b''
        for chunk in self._chunks():
            buff += chunk

            # Quit if there's no more data to read
            if not buff:
                break

            # Sanity check: after a socket read we should always have at least
            # 6 bytes in the buffer
            if len(buff) < 6:
                break

            while len(buff) > 6:
                token_length = APNs.unpacked_ushort_big_endian(buff[4:6])
                bytes_to_read = 6 + token_length
                if len(buff) >= bytes_to_read:
                    fail_time_unix = APNs.unpacked_uint_big_endian(buff[0:4])
                    fail_time = datetime.utcfromtimestamp(fail_time_unix)
                    token = b2a_hex(buff[6:bytes_to_read])

                    yield (token, fail_time)

                    # Remove data for current token from buffer
                    buff = buff[bytes_to_read:]
                else:
                    # break out of inner while loop - i.e. go and fetch
                    # some more data and append to buffer
                    break

class GatewayConnection(APNsConnection):
    """
    A class that represents a connection to the APNs gateway server
    """

    def __init__(self, use_sandbox=False, **kwargs):
        super(GatewayConnection, self).__init__(**kwargs)
        self.server = (
            'gateway.push.apple.com',
            'gateway.sandbox.push.apple.com')[use_sandbox]
        self.port = 2195
        if self.enhanced == True: #start error-response monitoring thread
            self._last_activity_time = time.time()

            self._send_lock = threading.RLock()
            self._error_response_handler_worker = None
            self._response_listener = None

        self._sent_notifications = collections.deque(maxlen=SENT_BUFFER_QTY)

    def _init_error_response_handler_worker(self):
        self._send_lock = threading.RLock()
        self._error_response_handler_worker = self.ErrorResponseHandlerWorker(apns_connection=self)
        self._error_response_handler_worker.start()
        _logger.debug("initialized error-response handler worker")

    def _get_notification(self, token_hex, payload):
        """
        Takes a token as a hex string and a payload as a Python dict and sends
        the notification
        """
        token_bin = a2b_hex(token_hex)
        token_length_bin = APNs.packed_ushort_big_endian(len(token_bin))
        payload_json = payload.json()
        payload_length_bin = APNs.packed_ushort_big_endian(len(payload_json))

        zero_byte = '\0'
        if sys.version_info[0] != 2:
            zero_byte = bytes(zero_byte, 'utf-8')
        notification = (zero_byte + token_length_bin + token_bin
            + payload_length_bin + payload_json)

        return notification

    def _get_enhanced_notification(self, token_hex, payload, identifier, expiry):
        """
        form notification data in an enhanced format
        """
        token = a2b_hex(token_hex)
        payload = payload.json()
        fmt = ENHANCED_NOTIFICATION_FORMAT % len(payload)
        notification = pack(fmt, ENHANCED_NOTIFICATION_COMMAND, identifier, expiry,
                            TOKEN_LENGTH, token, len(payload), payload)
        return notification

    def send_notification(self, token_hex, payload, identifier=0, expiry=0):
        """
        in enhanced mode, send_notification may return error response from APNs if any
        """
        if self.enhanced:
            self._last_activity_time = time.time()
            message = self._get_enhanced_notification(token_hex, payload,
                                                           identifier, expiry)

            for i in range(WRITE_RETRY):
                try:
                    with self._send_lock:
                        self._make_sure_error_response_handler_worker_alive()
                        self.write(message)
                        self._sent_notifications.append(dict({'id': identifier, 'message': message}))
                    break
                except socket_error as e:
                    delay = 10 + (i * 2)
                    _logger.exception("sending notification with id:" + str(identifier) +
                                 " to APNS failed: " + str(type(e)) + ": " + str(e) +
                                 " in " + str(i+1) + "th attempt, will wait " + str(delay) + " secs for next action")
                    time.sleep(delay) # wait potential error-response to be read

        else:
            self.write(self._get_notification(token_hex, payload))

    def _make_sure_error_response_handler_worker_alive(self):
        if (not self._error_response_handler_worker
            or not self._error_response_handler_worker.is_alive()):
            self._init_error_response_handler_worker()
            TIMEOUT_SEC = 10
            for _ in range(TIMEOUT_SEC):
                if self._error_response_handler_worker.is_alive():
                    _logger.debug("error response handler worker is running")
                    return
                time.sleep(1)
            _logger.warning("error response handler worker is not started after %s secs" % TIMEOUT_SEC)

    def send_notification_multiple(self, frame):
        self._sent_notifications += frame.get_notifications(self)
        return self.write(frame.get_frame())

    def register_response_listener(self, response_listener):
        self._response_listener = response_listener

    def force_close(self):
        if self._error_response_handler_worker:
            self._error_response_handler_worker.close()

    def _is_idle_timeout(self):
        TIMEOUT_IDLE = 30
        return (time.time() - self._last_activity_time) >= TIMEOUT_IDLE

    class ErrorResponseHandlerWorker(threading.Thread):
        def __init__(self, apns_connection):
            threading.Thread.__init__(self, name=self.__class__.__name__)
            self._apns_connection = apns_connection
            self._close_signal = False

        def close(self):
            self._close_signal = True

        def run(self):
            while True:
                if self._close_signal:
                    _logger.debug("received close thread signal")
                    break

                if self._apns_connection._is_idle_timeout():
                    idled_time = (time.time() - self._apns_connection._last_activity_time)
                    _logger.debug("connection idle after %d secs" % idled_time)
                    break

                if not self._apns_connection.connection_alive:
                    time.sleep(1)
                    continue

                try:
                    rlist, _, _ = select.select([self._apns_connection._connection()], [], [], WAIT_READ_TIMEOUT_SEC)

                    if len(rlist) > 0: # there's some data from APNs
                        with self._apns_connection._send_lock:
                            buff = self._apns_connection.read(ERROR_RESPONSE_LENGTH)
                            if len(buff) == ERROR_RESPONSE_LENGTH:
                                command, status, identifier = unpack(ERROR_RESPONSE_FORMAT, buff)
                                if 8 == command: # there is error response from APNS
                                    error_response = (status, identifier)
                                    if self._apns_connection._response_listener:
                                        self._apns_connection._response_listener(Util.convert_error_response_to_dict(error_response))
                                    _logger.info("got error-response from APNS:" + str(error_response))
                                    self._apns_connection._disconnect()
                                    self._resend_notifications_by_id(identifier)
                            if len(buff) == 0:
                                _logger.warning("read socket got 0 bytes data") #DEBUG
                                self._apns_connection._disconnect()

                except socket_error as e: # APNS close connection arbitrarily
                    _logger.exception("exception occur when reading APNS error-response: " + str(type(e)) + ": " + str(e)) #DEBUG
                    self._apns_connection._disconnect()
                    continue

                time.sleep(0.1) #avoid crazy loop if something bad happened. e.g. using invalid certificate

            self._apns_connection._disconnect()
            _logger.debug("error-response handler worker closed") #DEBUG

        def _resend_notifications_by_id(self, failed_identifier):
            fail_idx = Util.getListIndexFromID(self._apns_connection._sent_notifications, failed_identifier)
            #pop-out success notifications till failed one
            self._resend_notification_by_range(fail_idx+1, len(self._apns_connection._sent_notifications))
            return

        def _resend_notification_by_range(self, start_idx, end_idx):
            self._apns_connection._sent_notifications = collections.deque(itertools.islice(self._apns_connection._sent_notifications, start_idx, end_idx))
            _logger.info("resending %s notifications to APNS" % len(self._apns_connection._sent_notifications)) #DEBUG
            for sent_notification in self._apns_connection._sent_notifications:
                _logger.debug("resending notification with id:" + str(sent_notification['id']) + " to APNS") #DEBUG
                try:
                    self._apns_connection.write(sent_notification['message'])
                except socket_error as e:
                    _logger.exception("resending notification with id:" + str(sent_notification['id']) + " failed: " + str(type(e)) + ": " + str(e)) #DEBUG
                    break
                time.sleep(DELAY_RESEND_SEC) #DEBUG

class Util(object):
    @classmethod
    def getListIndexFromID(this_class, the_list, identifier):
        return next(index for (index, d) in enumerate(the_list)
                        if d['id'] == identifier)
    @classmethod
    def convert_error_response_to_dict(this_class, error_response_tuple):
        return {ER_STATUS: error_response_tuple[0], ER_IDENTIFER: error_response_tuple[1]}
